import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 待修改参数
    # img_path
    # weights_path
    img_path = "../../data_set/flower_data/flower_photos/roses/12240303_80d87f77a3_n.jpg"   #---待预测的图片路径
    weights_path = "./weights/model-0.pth"     #---train训练好的模型参数【务必与指定模型匹配】
    # 获取图像名(不含后缀)
    img_name = os.path.splitext(os.path.basename(img_path))[0]

    # 1. 加载图片并预处理
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img0 = Image.open(img_path)
    
    # [N, C, H, W]
    img = data_transform(img0)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # 2. 创建模型并加载训练好的模型权重
    # create model
    model = resnet34(num_classes=5).to(device)

    # load model weights
    assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # 3. 模型预测
    # prediction
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)      # 预测概率
        predict_cla = torch.argmax(predict).numpy() # 预测概率最大对应的索引

    # read class_indict
    json_path = './class_indices.json'     #--索引标签对应文件：如果不存在，就按正常的索引号输出
    class_indict = range(len(predict))

    if os.path.exists(json_path):
        # 标签字典存在的时候，对应标签
        with open(json_path, "r") as f:
            class_indict = json.load(f)
    else:
        print("file: '{}' dose not exist.".format(json_path))

    

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    
    # 在原图上增加【标题】（最有可能的预测类别和概率信息）
    plt.imshow(img0)
    plt.title(print_res)
    plt.savefig("predict_" + img_name + ".jpg")   #--centos中需要
    plt.show()

    # 依次打印各个类别的预测信息
    print("result: " + print_res, end = '\n ----- \n')
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))


if __name__ == '__main__':
    main()
